{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Skip-gram with negative sampling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I recommend you take a look at these material first."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf\n",
    "* http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import nltk\n",
    "import random\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "flatten = lambda l: [item for sublist in l for item in sublist]\n",
    "random.seed(1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3.0.post4\n",
      "3.2.4\n"
     ]
    }
   ],
   "source": [
    "print(torch.__version__)\n",
    "print(nltk.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "USE_CUDA = torch.cuda.is_available()\n",
    "gpus = [0]\n",
    "torch.cuda.set_device(gpus[0])\n",
    "\n",
    "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def getBatch(batch_size, train_data):\n",
    "    random.shuffle(train_data)\n",
    "    sindex = 0\n",
    "    eindex = batch_size\n",
    "    while eindex < len(train_data):\n",
    "        batch = train_data[sindex: eindex]\n",
    "        temp = eindex\n",
    "        eindex = eindex + batch_size\n",
    "        sindex = temp\n",
    "        yield batch\n",
    "    \n",
    "    if eindex >= len(train_data):\n",
    "        batch = train_data[sindex:]\n",
    "        yield batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def prepare_sequence(seq, word2index):\n",
    "    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"<UNK>\"], seq))\n",
    "    return Variable(LongTensor(idxs))\n",
    "\n",
    "def prepare_word(word, word2index):\n",
    "    return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"<UNK>\"]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data load and Preprocessing "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n",
    "corpus = [[word.lower() for word in sent] for sent in corpus]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exclude sparse words "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "word_count = Counter(flatten(corpus))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "MIN_COUNT = 3\n",
    "exclude = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for w, c in word_count.items():\n",
    "    if c < MIN_COUNT:\n",
    "        exclude.append(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare train data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "vocab = list(set(flatten(corpus)) - set(exclude))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "word2index = {}\n",
    "for vo in vocab:\n",
    "    if word2index.get(vo) is None:\n",
    "        word2index[vo] = len(word2index)\n",
    "        \n",
    "index2word = {v:k for k, v in word2index.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "WINDOW_SIZE = 5\n",
    "windows =  flatten([list(nltk.ngrams(['<DUMMY>'] * WINDOW_SIZE + c + ['<DUMMY>'] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])\n",
    "\n",
    "train_data = []\n",
    "\n",
    "for window in windows:\n",
    "    for i in range(WINDOW_SIZE * 2 + 1):\n",
    "        if window[i] in exclude or window[WINDOW_SIZE] in exclude: \n",
    "            continue # min_count\n",
    "        if i == WINDOW_SIZE or window[i] == '<DUMMY>': \n",
    "            continue\n",
    "        train_data.append((window[WINDOW_SIZE], window[i]))\n",
    "\n",
    "X_p = []\n",
    "y_p = []\n",
    "\n",
    "for tr in train_data:\n",
    "    X_p.append(prepare_word(tr[0], word2index).view(1, -1))\n",
    "    y_p.append(prepare_word(tr[1], word2index).view(1, -1))\n",
    "    \n",
    "train_data = list(zip(X_p, y_p))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50242"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build Unigram Distribution**0.75 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$P(w)=U(w)^{3/4}/Z$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "Z = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "word_count = Counter(flatten(corpus))\n",
    "num_total_words = sum([c for w, c in word_count.items() if w not in exclude])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "unigram_table = []\n",
    "\n",
    "for vo in vocab:\n",
    "    unigram_table.extend([vo] * int(((word_count[vo]/num_total_words)**0.75)/Z))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "478 3500\n"
     ]
    }
   ],
   "source": [
    "print(len(vocab), len(unigram_table))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Negative Sampling "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def negative_sampling(targets, unigram_table, k):\n",
    "    batch_size = targets.size(0)\n",
    "    neg_samples = []\n",
    "    for i in range(batch_size):\n",
    "        nsample = []\n",
    "        target_index = targets[i].data.cpu().tolist()[0] if USE_CUDA else targets[i].data.tolist()[0]\n",
    "        while len(nsample) < k: # num of sampling\n",
    "            neg = random.choice(unigram_table)\n",
    "            if word2index[neg] == target_index:\n",
    "                continue\n",
    "            nsample.append(neg)\n",
    "        neg_samples.append(prepare_sequence(nsample, word2index).view(1, -1))\n",
    "    \n",
    "    return torch.cat(neg_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modeling "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/02.skipgram-objective.png\">\n",
    "<center>borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture3.pdf</center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class SkipgramNegSampling(nn.Module):\n",
    "    \n",
    "    def __init__(self, vocab_size, projection_dim):\n",
    "        super(SkipgramNegSampling, self).__init__()\n",
    "        self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding\n",
    "        self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding\n",
    "        self.logsigmoid = nn.LogSigmoid()\n",
    "                \n",
    "        initrange = (2.0 / (vocab_size + projection_dim))**0.5 # Xavier init\n",
    "        self.embedding_v.weight.data.uniform_(-initrange, initrange) # init\n",
    "        self.embedding_u.weight.data.uniform_(-0.0, 0.0) # init\n",
    "        \n",
    "    def forward(self, center_words, target_words, negative_words):\n",
    "        center_embeds = self.embedding_v(center_words) # B x 1 x D\n",
    "        target_embeds = self.embedding_u(target_words) # B x 1 x D\n",
    "        \n",
    "        neg_embeds = -self.embedding_u(negative_words) # B x K x D\n",
    "        \n",
    "        positive_score = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1\n",
    "        negative_score = torch.sum(neg_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2), 1).view(negs.size(0), -1) # BxK -> Bx1\n",
    "        \n",
    "        loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score)\n",
    "        \n",
    "        return -torch.mean(loss)\n",
    "    \n",
    "    def prediction(self, inputs):\n",
    "        embeds = self.embedding_v(inputs)\n",
    "        \n",
    "        return embeds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "EMBEDDING_SIZE = 30 \n",
    "BATCH_SIZE = 256\n",
    "EPOCH = 100\n",
    "NEG = 10 # Num of Negative Sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "losses = []\n",
    "model = SkipgramNegSampling(len(word2index), EMBEDDING_SIZE)\n",
    "if USE_CUDA:\n",
    "    model = model.cuda()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 0, mean_loss : 1.06\n",
      "Epoch : 10, mean_loss : 0.86\n",
      "Epoch : 20, mean_loss : 0.79\n",
      "Epoch : 30, mean_loss : 0.74\n",
      "Epoch : 40, mean_loss : 0.71\n",
      "Epoch : 50, mean_loss : 0.69\n",
      "Epoch : 60, mean_loss : 0.67\n",
      "Epoch : 70, mean_loss : 0.65\n",
      "Epoch : 80, mean_loss : 0.64\n",
      "Epoch : 90, mean_loss : 0.63\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(EPOCH):\n",
    "    for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
    "        \n",
    "        inputs, targets = zip(*batch)\n",
    "        \n",
    "        inputs = torch.cat(inputs) # B x 1\n",
    "        targets = torch.cat(targets) # B x 1\n",
    "        negs = negative_sampling(targets, unigram_table, NEG)\n",
    "        model.zero_grad()\n",
    "\n",
    "        loss = model(inputs, targets, negs)\n",
    "        \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "        losses.append(loss.data.tolist()[0])\n",
    "    if epoch % 10 == 0:\n",
    "        print(\"Epoch : %d, mean_loss : %.02f\" % (epoch, np.mean(losses)))\n",
    "        losses = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def word_similarity(target, vocab):\n",
    "    if USE_CUDA:\n",
    "        target_V = model.prediction(prepare_word(target, word2index))\n",
    "    else:\n",
    "        target_V = model.prediction(prepare_word(target, word2index))\n",
    "    similarities = []\n",
    "    for i in range(len(vocab)):\n",
    "        if vocab[i] == target: \n",
    "            continue\n",
    "        \n",
    "        if USE_CUDA:\n",
    "            vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n",
    "        else:\n",
    "            vector = model.prediction(prepare_word(list(vocab)[i], word2index))\n",
    "        \n",
    "        cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0]\n",
    "        similarities.append([vocab[i], cosine_sim])\n",
    "    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 212,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'passengers'"
      ]
     },
     "execution_count": 212,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test = random.choice(list(vocab))\n",
    "test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 213,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['am', 0.7353377342224121],\n",
       " ['passenger', 0.7154150605201721],\n",
       " ['cook', 0.6829826831817627],\n",
       " ['new', 0.6648461818695068],\n",
       " ['bedford', 0.6283411383628845],\n",
       " ['besides', 0.5972960591316223],\n",
       " ['themselves', 0.5964340567588806],\n",
       " ['grow', 0.5957046151161194],\n",
       " ['tell', 0.5952941179275513],\n",
       " ['get', 0.5943044424057007]]"
      ]
     },
     "execution_count": 213,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word_similarity(test, vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
